# egnn_net.py
import torch
from torch import nn
import torch.optim as optim
import numpy as np
import networkx as nx
import dgl
from dgl.nn.pytorch.conv import EGNNConv

def create_dgl_graph(G, features, directed=False):
    """Build a DGL graph over the line graph of ``G`` and attach node features.

    Every node of ``nx.line_graph(G)`` (i.e. every edge of ``G``) becomes one
    DGL node. Every line-graph edge becomes a DGL edge (plus its reverse when
    ``directed`` is false), and each node additionally gets a self loop.

    Args:
        G: networkx graph whose line graph is converted.
        features: mapping from line-graph node label to a 1-D feature vector
            exposing ``.shape``; all vectors are taken to have the length of
            the first one.
        directed: when false, each line-graph edge is inserted in both
            directions.

    Returns:
        ``(g, edge_map)`` where ``g`` is the DGL graph (moved to CUDA when
        available) carrying ``g.ndata['feat']``, and ``edge_map`` maps each
        line-graph node label to its DGL node id.
    """
    line_graph = nx.line_graph(G)
    edge_map = {label: pos for pos, label in enumerate(line_graph.nodes())}

    sources = []
    targets = []

    # Line-graph adjacency; the reverse edge directly follows the forward one
    # so edge ids stay interleaved.
    for tail, head in line_graph.edges():
        tail_id = edge_map[tail]
        head_id = edge_map[head]
        sources.append(tail_id)
        targets.append(head_id)
        if directed:
            continue
        sources.append(head_id)
        targets.append(tail_id)

    # Self loops come last, in line-graph node order.
    loop_ids = [edge_map[label] for label in line_graph.nodes()]
    sources.extend(loop_ids)
    targets.extend(loop_ids)

    src_ids = torch.tensor(sources, dtype=torch.long)
    dst_ids = torch.tensor(targets, dtype=torch.long)
    g = dgl.graph((src_ids, dst_ids))

    # Node features: one feature vector per original edge of G.
    first_vector = next(iter(features.values()))
    feat = torch.zeros((G.number_of_edges(), first_vector.shape[0]))
    for label, row in edge_map.items():
        feat[row] = torch.FloatTensor(features[label])
    g.ndata['feat'] = feat

    if torch.cuda.is_available():
        return g.cuda(), edge_map
    return g, edge_map


def get_feat_ids(G, flows, edge_map):
    """Return the node ids of the edges of ``G`` that have an entry in ``flows``.

    Args:
        G: object whose ``edges()`` yields edge labels.
        flows: container supporting ``in``; only edges present in it are kept.
        edge_map: mapping from edge label to integer node id; it must contain
            every edge of ``G`` that is present in ``flows``.

    Returns:
        A 1-D ``torch.long`` tensor of ids in ``G.edges()`` order, placed on
        ``cuda:0`` when CUDA is available and on the CPU otherwise.

    Raises:
        KeyError: if an edge found in ``flows`` is missing from ``edge_map``.
    """
    ids = [edge_map[e] for e in G.edges() if e in flows]
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # Force an integer dtype: without it an empty selection becomes a float32
    # tensor, which cannot be used to index another tensor.
    return torch.tensor(ids, dtype=torch.long, device=device)


class EGNNNet(nn.Module):
    """Two stacked ``EGNNConv`` layers producing one scalar per graph node.

    The first layer maps ``in_feats`` to ``hidden_feats``; the second maps
    ``hidden_feats`` to a single output channel, which is squeezed and passed
    through ``activation``. The model owns its Adam optimizer and lives on
    ``cuda:0`` when CUDA is available, otherwise on the CPU.
    """

    def __init__(self, in_feats, hidden_feats, n_steps, lr, early_stop=10, activation=torch.sigmoid):
        """Create the layers, pick the device and build the optimizer.

        Args:
            in_feats: size of the input node features.
            hidden_feats: hidden size used by both layers.
            n_steps: maximum number of optimisation steps run by ``fit``.
            lr: Adam learning rate.
            early_stop: number of previous validation losses averaged by the
                early-stopping test.
            activation: callable applied to the squeezed network output.
        """
        super().__init__()
        # conv1: in_feats -> hidden_feats
        self.conv1 = EGNNConv(in_feats, hidden_feats, hidden_feats, edge_feat_size=0)
        # conv2: hidden_feats -> 1
        self.conv2 = EGNNConv(hidden_feats, hidden_feats, 1, edge_feat_size=0)

        self.n_steps = n_steps
        self.early_stop = early_stop
        self.activation = activation
        self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        # Move the parameters first and only then hand them to the optimizer.
        self.to(self.device)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, g, feats):
        """Run both layers and return one activated value per node.

        Args:
            g: graph passed to both ``EGNNConv`` layers.
            feats: node feature tensor; its first dimension is the node count.

        Returns:
            Tensor with the trailing size-1 channel squeezed away, after
            ``activation``.
        """
        # Dummy 1-D coordinate per node (all zeros).
        coord = torch.zeros(feats.size(0), 1, device=self.device)
        h, coord = self.conv1(g, feats.to(self.device), coord)
        h = torch.relu(h)
        h, coord = self.conv2(g, h, coord)
        # h is (N, 1) -> squeeze + activation
        return self.activation(h.squeeze(-1))

    def train(self, g=True, edge_map=None, train_idx=None, y_train=None,
              valid_idx=None, y_valid=None, verbose=False, *, mode=None):
        """Fit the model, or toggle training mode like ``nn.Module.train``.

        This method shadows ``nn.Module.train``. ``nn.Module.eval()`` calls
        ``self.train(False)``, so the call is dispatched on its arguments to
        keep ``eval()`` / ``train(mode)`` working:

        * ``train()``, ``train(True/False)`` or ``train(mode=...)`` set the
          training flag and return ``self``.
        * ``train(g, edge_map, train_idx, y_train, valid_idx, y_valid,
          verbose=False)`` runs the optimisation loop (see ``fit``) and
          returns ``None``.
        """
        if mode is not None:
            return super().train(mode)
        if isinstance(g, bool):
            return super().train(g)
        self.fit(g, edge_map, train_idx, y_train, valid_idx, y_valid, verbose=verbose)
        return None

    def fit(self, g, edge_map, train_idx, y_train, valid_idx, y_valid, verbose=False):
        """Optimise MSE on the training nodes with early stopping on validation.

        Args:
            g: graph exposing ``g.ndata['feat']``.
            edge_map: unused; kept so the signature matches ``train``.
            train_idx: index selecting the training outputs.
            y_train: targets for ``train_idx``.
            valid_idx: index selecting the validation outputs.
            y_valid: targets for ``valid_idx``.
            verbose: print losses every 1000 epochs and on early stop.
        """
        loss_fn = nn.MSELoss()
        # Targets do not change between epochs; transfer them once.
        y_train = y_train.to(self.device)
        y_valid = y_valid.to(self.device)
        valid_losses = []

        for epoch in range(self.n_steps):
            self.optimizer.zero_grad()
            out = self(g, g.ndata['feat'])
            loss_t = loss_fn(out[train_idx], y_train)
            loss_v = loss_fn(out[valid_idx], y_valid)
            loss_t.backward()
            self.optimizer.step()

            valid_losses.append(loss_v.item())
            if verbose and epoch % 1000 == 0:
                print(f"epoch {epoch} — train {loss_t.item():.4f}, valid {loss_v.item():.4f}")

            # Early stopping: stop once the newest validation loss exceeds the
            # mean of the `early_stop` losses that preceded it.
            if epoch > self.early_stop:
                recent = valid_losses[-(self.early_stop + 1):-1]
                if valid_losses[-1] > np.mean(recent):
                    if verbose:
                        print("Early stopping.")
                    break


